{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import torch, json, numpy\n",
    "from netdissect import proggan, nethook, easydict, zdataset\n",
    "from netdissect.plotutil import plot_tensor_images, plot_max_heatmap\n",
    "from scipy.stats import wasserstein_distance\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Run on the GPU when CUDA is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create the outdoor church model, make layer4 editable, and retain the activations of every layer from layer4 onward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the church-outdoor generator from its checkpoint and move it to the chosen device.\n",
    "checkpoint_path = 'models/karras/churchoutdoor_lsun.pth'\n",
    "model = proggan.from_pth_file(checkpoint_path).to(device)\n",
    "\n",
    "# layer4 is the only layer registered for editing; later cells set its ablation/replacement.\n",
    "nethook.edit_layers(model, ['layer4'])\n",
    "\n",
    "# Retain layer4..layer14 plus 'output_256x256' (12 layers in total).\n",
    "instrumented_layers = ['layer{:d}'.format(index) for index in range(4, 15)]\n",
    "instrumented_layers.append('output_256x256')\n",
    "nethook.retain_layers(model, instrumented_layers)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Select door units, and visualize them on six different images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the layer4 record in the dissection report, then its 'door-iou' ranking.\n",
    "dissect = easydict.load_json('dissect/churchoutdoor/dissect.json')\n",
    "lrec = next(rec for rec in dissect.layers if rec.layer == 'layer4')\n",
    "rrec = next(rank for rank in lrec.rankings if rank.name == 'door-iou')\n",
    "\n",
    "# Keep the 20 units with the smallest ranking scores (argsort is ascending).\n",
    "# NOTE(review): presumably a lower score marks a better door unit -- confirm against dissect.json.\n",
    "unit_order = numpy.argsort(rrec.score)\n",
    "units = torch.from_numpy(unit_order[:20])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick six entries out of z_sample_for_model(model, 350) and move them to the device.\n",
    "sample_indices = [13, 34, 109, 134, 139, 346]\n",
    "zbatch = zdataset.z_sample_for_model(model, 350)[sample_indices].to(device)\n",
    "\n",
    "# Plot the generated images, then the selected units' retained layer4 activations.\n",
    "generated = model(zbatch)\n",
    "plot_tensor_images(generated)\n",
    "door_unit_acts = model.retained['layer4'][:, units]\n",
    "plot_max_heatmap(door_unit_acts, shape=(256, 256))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the activation of each of the door features at position (5, 3) of the first image as the `canonical door` (taking each feature's maximum activation is the commented-out alternative)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward pass so that model.retained holds the activations for zbatch.\n",
    "model(zbatch)\n",
    "layer4_acts = model.retained['layer4']\n",
    "\n",
    "# Canonical door: the selected units of the first image, read at layer4 position (5, 3).\n",
    "# (Alternative: each unit's spatial maximum,\n",
    "#  layer4_acts[0][units].view(len(units), -1).max(1)[0])\n",
    "door_feature = layer4_acts[0][units][:, 5, 3]\n",
    "\n",
    "# Copy of the layer4 activations of every batch entry except the first, so it can be edited.\n",
    "baseline_target = layer4_acts[1:].clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create modification #1: put a door in the church wall.\n",
    "\n",
    "How does this change become RGB?  Trace through the layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the baseline layer4 activations and write the door feature at position (7, 4).\n",
    "modified_target = baseline_target.clone()\n",
    "modified_target[:, units, 7, 4] = door_feature\n",
    "\n",
    "# NOTE(review): the two-entry ablation mask presumably expects a two-image batch,\n",
    "# while zbatch is built from six samples -- confirm the shapes still broadcast.\n",
    "model.ablation['layer4'] = torch.tensor([1, 0])[:, None, None, None]\n",
    "model.replacement['layer4'] = modified_target\n",
    "plot_tensor_images(model(zbatch))\n",
    "model.ablation['layer4'] = None\n",
    "model.replacement['layer4'] = None\n",
    "\n",
    "# For each retained layer, compare batch entry 0 with batch entry 1 channel by channel.\n",
    "graph1 = []\n",
    "for layer in instrumented_layers:\n",
    "    features = model.retained[layer]\n",
    "    # Per-channel mean of entry 1; both entries are divided by it.\n",
    "    channel_scale = features[1].view(features.shape[1], -1).mean(1)[:, None, None]\n",
    "    # scipy works on host arrays, so detach and copy to the CPU before converting.\n",
    "    changed_features = (features[0] / channel_scale).detach().cpu().numpy()\n",
    "    base_features = (features[1] / channel_scale).detach().cpu().numpy()\n",
    "    wd = torch.tensor([wasserstein_distance(changed.ravel(), base.ravel())\n",
    "                       for changed, base in zip(changed_features, base_features)])\n",
    "    norm = wd.mean()\n",
    "    print(layer, norm.item())\n",
    "    graph1.append(norm.item())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create modification #2: try to put a door in the sky"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the baseline layer4 activations and write the door feature at position (5, 6).\n",
    "modified_target_2 = baseline_target.clone()\n",
    "modified_target_2[:, units, 5, 6] = door_feature\n",
    "\n",
    "# NOTE(review): the two-entry ablation mask presumably expects a two-image batch,\n",
    "# while zbatch is built from six samples -- confirm the shapes still broadcast.\n",
    "model.ablation['layer4'] = torch.tensor([1, 0])[:, None, None, None]\n",
    "model.replacement['layer4'] = modified_target_2\n",
    "plot_tensor_images(model(zbatch))\n",
    "model.ablation['layer4'] = None\n",
    "model.replacement['layer4'] = None\n",
    "\n",
    "# For each retained layer, compare batch entry 0 with batch entry 1 channel by channel.\n",
    "graph2 = []\n",
    "for layer in instrumented_layers:\n",
    "    features = model.retained[layer]\n",
    "    # Per-channel mean of entry 1; both entries are divided by it.\n",
    "    channel_scale = features[1].view(features.shape[1], -1).mean(1)[:, None, None]\n",
    "    # scipy works on host arrays, so detach and copy to the CPU before converting.\n",
    "    changed_features = (features[0] / channel_scale).detach().cpu().numpy()\n",
    "    base_features = (features[1] / channel_scale).detach().cpu().numpy()\n",
    "    wd = torch.tensor([wasserstein_distance(changed.ravel(), base.ravel())\n",
    "                       for changed, base in zip(changed_features, base_features)])\n",
    "    norm = wd.mean()\n",
    "    print(layer, norm.item())\n",
    "    graph2.append(norm.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the per-layer distances recorded for the sky edit and the wall edit.\n",
    "# The x axis starts at layer number 4 and has one point per recorded distance,\n",
    "# so it follows the length of the graphs instead of a fixed 12 points.\n",
    "plt.plot(range(4, 4 + len(graph2)), graph2, label=\"sky\")\n",
    "plt.plot(range(4, 4 + len(graph1)), graph1, label=\"wall\")\n",
    "plt.xlabel('layer after intervention')\n",
    "plt.ylabel('Wasserstein feat dist')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Some ground truth evaluation.\n",
    "\n",
    "- 0 = no significant change\n",
    "- 1 = changed or added a door\n",
    "- 2 = changed or added a window\n",
    "- 3 = some other change\n",
    "\n",
    "We don't bother evaluating perturbations on the edges."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth labels, read by the evaluation cells as gt[r, c, i]:\n",
    "# r, c index the 6x6 grid of swept locations and i indexes the image.\n",
    "# Codes: 0 = no significant change, 1 = door, 2 = window, 3 = some other change.\n",
    "gt = torch.tensor([\n",
    "  [[0,0,0,0,0,0], [0,0,0,0,0,3], [0,0,2,0,0,0], [0,3,0,0,0,0], [0,0,3,0,2,0], [0,0,0,0,0,0]],\n",
    "  [[0,0,0,0,0,0], [2,0,0,2,2,2], [0,2,2,0,0,0], [0,0,0,2,2,0], [2,0,2,2,2,2], [2,0,0,2,3,0]],\n",
    "  [[0,0,0,0,0,0], [0,0,0,0,2,3], [0,2,2,0,2,0], [0,0,0,0,1,2], [0,0,0,0,2,2], [0,0,0,3,1,0]],\n",
    "  [[0,0,0,3,0,0], [0,2,0,0,3,3], [0,2,0,2,2,0], [0,2,2,0,2,2], [0,0,2,0,2,2], [0,0,0,3,2,0]],\n",
    "  [[0,0,0,1,1,0], [0,0,0,1,1,2], [0,2,0,1,0,0], [0,3,1,1,2,2], [0,0,1,0,0,1], [0,3,2,1,2,2]],\n",
    "  [[0,2,0,1,1,1], [1,2,0,0,1,1], [0,0,1,0,1,0], [0,0,1,1,1,0], [0,2,0,0,1,1], [0,2,1,1,1,0]]\n",
    "])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Interleave two copies of every latent: even entries stay unedited, odd entries get the edit.\n",
    "    myzbatch = torch.zeros((zbatch.shape[0] * 2,) + zbatch.shape[1:]).to(device)\n",
    "    myzbatch[0::2] = zbatch\n",
    "    myzbatch[1::2] = zbatch\n",
    "    ablation = torch.zeros(len(myzbatch))[:,None,None,None]\n",
    "    ablation[1::2] = 1\n",
    "\n",
    "    # Unedited pass: capture the door feature and the baseline layer4 activations.\n",
    "    model.ablation['layer4'] = None\n",
    "    model.replacement['layer4'] = None\n",
    "    model(myzbatch)\n",
    "    door_feature = model.retained['layer4'][0][units][:,5,3]\n",
    "    baseline_target = model.retained['layer4'].clone()\n",
    "\n",
    "    # all_results[location][image][layer]; locations are visited row-major over rows/cols 1..6.\n",
    "    all_results = []\n",
    "    for r in range(1,7):\n",
    "        for c in range(1,7):\n",
    "            # Write the door feature at (r, c) into the odd entries only, then re-run the model.\n",
    "            modified_target_2 = baseline_target.clone()\n",
    "            modified_target_2[1::2,units,r,c] = door_feature\n",
    "            model.ablation['layer4'] = ablation\n",
    "            model.replacement['layer4'] = modified_target_2\n",
    "            plot_tensor_images(model(myzbatch))\n",
    "            model.ablation['layer4'] = None\n",
    "            model.replacement['layer4'] = None\n",
    "\n",
    "            # Per-channel mean |activation| over the even entries. It does not depend on\n",
    "            # the image being compared, so compute it once per layer for this location.\n",
    "            channel_scale = {}\n",
    "            for layer in instrumented_layers:\n",
    "                features = model.retained[layer]\n",
    "                channel_scale[layer] = features[0::2].permute(1,0,2,3).contiguous().view(\n",
    "                    features.shape[1], -1).abs().mean(1)[:,None,None]\n",
    "\n",
    "            loc_results = []\n",
    "            # One (even, odd) pair per image; follows the batch size instead of a fixed 12.\n",
    "            for case in range(0, len(myzbatch), 2):\n",
    "                graphv = []\n",
    "                for layer in instrumented_layers:\n",
    "                    features = model.retained[layer]\n",
    "                    # scipy works on host arrays, so detach and copy to the CPU before converting.\n",
    "                    changed_features = (features[case+1] / channel_scale[layer]).detach().cpu().numpy()\n",
    "                    base_features = (features[case] / channel_scale[layer]).detach().cpu().numpy()\n",
    "                    wd = torch.tensor([wasserstein_distance(changed.ravel(), base.ravel())\n",
    "                                       for changed, base in zip(changed_features, base_features)])\n",
    "                    graphv.append(wd.mean().item())\n",
    "                plt.plot(range(4, 4 + len(graphv)), graphv, label=\"%d p%d,%d\" % (case, r, c))\n",
    "                plt.legend()\n",
    "                plt.show()\n",
    "                loc_results.append(graphv)\n",
    "                print('%d location results' % len(loc_results))\n",
    "            all_results.append(loc_results)\n",
    "            print('%d results so far' % len(all_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Serialize the Wasserstein sweep results and write them out as JSON.\n",
    "with open('norm_was_dist_all_results.json', 'w') as f:\n",
    "    f.write(json.dumps(all_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort every per-image distance curve into 'effect' or 'noeffect' by its ground-truth label.\n",
    "effect, noeffect = [], []\n",
    "\n",
    "for loc, locg in enumerate(all_results):\n",
    "    # Locations were appended row-major over the 6x6 sweep.\n",
    "    r, c = divmod(loc, 6)\n",
    "    for i, graphv in enumerate(locg):\n",
    "        t = gt[r, c, i]\n",
    "        # Every image is counted here, including image 0.\n",
    "        bucket = effect if t != 0 else noeffect\n",
    "        bucket.append(graphv)\n",
    "\n",
    "# Average the curves of each group layer by layer.\n",
    "avg_effect = torch.tensor(effect).mean(0).numpy()\n",
    "avg_noeffect = torch.tensor(noeffect).mean(0).numpy()\n",
    "\n",
    "f = plt.figure(figsize=(2.5, 2))\n",
    "for curve, name in ((avg_effect, \"effect\"), (avg_noeffect, \"not\")):\n",
    "    plt.plot(range(4,16), curve, label=name)\n",
    "plt.legend(frameon=False)\n",
    "axes = plt.gca()\n",
    "for side in ('right', 'top'):\n",
    "    axes.spines[side].set_visible(False)\n",
    "plt.xlabel('layer after intervention')\n",
    "plt.ylabel('Wasserstein feat dist')\n",
    "plt.show()\n",
    "f.savefig(\"layerwass.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check on the activations retained from the most recent forward pass:\n",
    "# mean per-channel Wasserstein distance between batch entries 0 and 1 of layer5.\n",
    "layer = 'layer5'\n",
    "features = model.retained[layer]\n",
    "# scipy works on host arrays, so detach and copy to the CPU before converting.\n",
    "changed_features = features[0].detach().cpu().numpy()\n",
    "base_features = features[1].detach().cpu().numpy()\n",
    "wd = torch.tensor([wasserstein_distance(changed.ravel(), base.ravel())\n",
    "                   for changed, base in zip(changed_features, base_features)])\n",
    "wd.mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Interleave two copies of every latent: even entries stay unedited, odd entries get the edit.\n",
    "    doubled_shape = (zbatch.shape[0] * 2,) + zbatch.shape[1:]\n",
    "    myzbatch = torch.zeros(doubled_shape).to(device)\n",
    "    myzbatch[0::2] = zbatch\n",
    "    myzbatch[1::2] = zbatch\n",
    "    ablation = torch.zeros(len(myzbatch))[:, None, None, None]\n",
    "    ablation[1::2] = 1\n",
    "\n",
    "    # Unedited pass: capture the door feature and the baseline layer4 activations.\n",
    "    model.ablation['layer4'] = None\n",
    "    model.replacement['layer4'] = None\n",
    "    model(myzbatch)\n",
    "    door_feature = model.retained['layer4'][0][units][:, 5, 3]\n",
    "    baseline_target = model.retained['layer4'].clone()\n",
    "\n",
    "    # mean_results[location][image][layer]; locations are visited row-major over rows/cols 1..6.\n",
    "    mean_results = []\n",
    "    for r in range(1, 7):\n",
    "        for c in range(1, 7):\n",
    "            # Write the door feature at (r, c) into the odd entries only, then re-run the model.\n",
    "            modified_target_2 = baseline_target.clone()\n",
    "            modified_target_2[1::2, units, r, c] = door_feature\n",
    "            model.ablation['layer4'] = ablation\n",
    "            model.replacement['layer4'] = modified_target_2\n",
    "            plot_tensor_images(model(myzbatch))\n",
    "            model.ablation['layer4'] = None\n",
    "            model.replacement['layer4'] = None\n",
    "\n",
    "            loc_results = []\n",
    "            for case in range(0, 12, 2):\n",
    "                graphv = []\n",
    "                for layer in instrumented_layers:\n",
    "                    features = model.retained[layer]\n",
    "                    n_channels = features.shape[1]\n",
    "                    # Per-channel mean |activation| over the even (unedited) entries.\n",
    "                    per_channel = features[0::2].permute(1, 0, 2, 3).contiguous().view(n_channels, -1)\n",
    "                    avg_features = per_channel.abs().mean(1)[:, None, None]\n",
    "                    changed_features = features[case + 1] / avg_features\n",
    "                    base_features = features[case] / avg_features\n",
    "                    # Mean absolute difference between the edited and unedited activations.\n",
    "                    norm = (changed_features - base_features).abs().mean()\n",
    "                    graphv.append(norm.item())\n",
    "                plt.plot(range(4, 16), graphv, label=\"%d p%d,%d\" % (case, r, c))\n",
    "                plt.legend()\n",
    "                plt.show()\n",
    "                loc_results.append(graphv)\n",
    "                print('%d location results' % len(loc_results))\n",
    "            mean_results.append(loc_results)\n",
    "            print('%d results so far' % len(mean_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort every per-image difference curve into 'effect' or 'noeffect' by its ground-truth label.\n",
    "effect, noeffect = [], []\n",
    "\n",
    "for loc, locg in enumerate(mean_results):\n",
    "    # Locations were appended row-major over the 6x6 sweep.\n",
    "    r, c = divmod(loc, 6)\n",
    "    for i, graphv in enumerate(locg):\n",
    "        t = gt[r, c, i]\n",
    "        # Image 0 is left out of both groups.\n",
    "        if i != 0:\n",
    "            bucket = effect if t != 0 else noeffect\n",
    "            bucket.append(graphv)\n",
    "\n",
    "# Average the curves of each group layer by layer.\n",
    "avg_effect = torch.tensor(effect).mean(0).numpy()\n",
    "avg_noeffect = torch.tensor(noeffect).mean(0).numpy()\n",
    "\n",
    "f = plt.figure(figsize=(2.5,2))\n",
    "for curve, name in ((avg_effect, \"effect\"), (avg_noeffect, \"not\")):\n",
    "    plt.plot(range(4,16), curve, label=name)\n",
    "axes = plt.gca()\n",
    "for side in ('right', 'top'):\n",
    "    axes.spines[side].set_visible(False)\n",
    "plt.legend(frameon=False)\n",
    "plt.xlabel('layer after intervention')\n",
    "plt.ylabel('normalized feature diff')\n",
    "plt.show()\n",
    "f.savefig(\"layernorm.pdf\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Turn the nested result lists into tensors. view(6,6,6,12) reads them as\n",
    "# (row, col, image, layer); the permute reorders that to (image, layer, row, col).\n",
    "t_all_results = torch.tensor(all_results).view(6,6,6,12).permute(2,3,0,1)\n",
    "t_mean_results = torch.tensor(mean_results).view(6,6,6,12).permute(2,3,0,1)\n",
    "# Zero-filled 8x8 copies with the 6x6 results placed at rows/cols 1..6.\n",
    "# NOTE(review): the padded tensors are not passed to the plots below -- confirm which was intended.\n",
    "pad_all_results = torch.zeros(6,12,8,8)\n",
    "pad_all_results[:,:,1:7,1:7] = t_all_results\n",
    "pad_mean_results = torch.zeros(6,12,8,8)\n",
    "pad_mean_results[:,:,1:7,1:7] = t_mean_results\n",
    "\n",
    "# range(11,12) visits only index 11, the last of the 12 layer slots; the printed number is i + 4.\n",
    "for i in range(11,12):\n",
    "    print('wass', i + 4)\n",
    "    plot_max_heatmap(t_all_results[:,i:i+1,:,:], shape=(256,256))\n",
    "    print('mean', i + 4)\n",
    "    plot_max_heatmap(t_mean_results[:,i:i+1,:,:], shape=(256,256))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Interleave two copies of every latent: even entries stay unedited, odd entries get the edit.\n",
    "    myzbatch = torch.zeros((zbatch.shape[0] * 2,) + zbatch.shape[1:]).to(device)\n",
    "    myzbatch[0::2] = zbatch\n",
    "    myzbatch[1::2] = zbatch\n",
    "    ablation = torch.zeros(len(myzbatch))[:,None,None,None]\n",
    "    ablation[1::2] = 1\n",
    "\n",
    "    # Unedited pass: capture the door feature and the baseline layer4 activations.\n",
    "    model.ablation['layer4'] = None\n",
    "    model.replacement['layer4'] = None\n",
    "    model(myzbatch)\n",
    "    door_feature = model.retained['layer4'][0][units][:,5,3]\n",
    "    baseline_target = model.retained['layer4'].clone()\n",
    "\n",
    "    # full_results[location][image][layer]; locations are visited row-major over rows/cols 0..7.\n",
    "    full_results = []\n",
    "    for r in range(0,8):\n",
    "        for c in range(0,8):\n",
    "            # Write the door feature at (r, c) into the odd entries only, then re-run the model.\n",
    "            modified_target_2 = baseline_target.clone()\n",
    "            modified_target_2[1::2,units,r,c] = door_feature\n",
    "            model.ablation['layer4'] = ablation\n",
    "            model.replacement['layer4'] = modified_target_2\n",
    "            plot_tensor_images(model(myzbatch))\n",
    "            model.ablation['layer4'] = None\n",
    "            model.replacement['layer4'] = None\n",
    "            loc_results = []\n",
    "            for case in range(0,12,2):\n",
    "                graphv = []\n",
    "                for layer in instrumented_layers:\n",
    "                    features = model.retained[layer]\n",
    "                    # Per-channel mean |activation| over the even (unedited) entries.\n",
    "                    avg_features = features[0::2].permute(1,0,2,3).contiguous().view(features.shape[1], -1).abs().mean(1)[:,None,None]\n",
    "                    changed_features = features[case+1] / avg_features\n",
    "                    base_features = features[case] / avg_features\n",
    "                    # Mean absolute difference between the edited and unedited activations.\n",
    "                    diff = changed_features - base_features\n",
    "                    norm = diff.abs().mean()\n",
    "                    graphv.append(norm.item())\n",
    "                plt.plot(range(4,16), graphv, label=\"%d p%d,%d\" % (case, r, c))\n",
    "                plt.legend()\n",
    "                plt.show()\n",
    "                loc_results.append(graphv)\n",
    "                print('%d location results' % len(loc_results))\n",
    "            full_results.append(loc_results)\n",
    "            # Report progress of this sweep (full_results), not of the earlier mean_results sweep.\n",
    "            print('%d results so far' % len(full_results))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# view(8, 8, 6, 12) reads the sweep as (row, col, image, layer);\n",
    "# the permute reorders that to (image, layer, row, col).\n",
    "t_full_results = torch.tensor(full_results).view(8, 8, 6, 12).permute(2, 3, 0, 1)\n",
    "\n",
    "# One plot per layer slot; the printed number is the slot index plus 4.\n",
    "for layer_index in range(12):\n",
    "    layer_slice = t_full_results[:, layer_index:layer_index + 1, :, :]\n",
    "    print('full', layer_index + 4)\n",
    "    plot_max_heatmap(layer_slice, shape=(256, 256))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}